{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Enhancing RAG Reasoning with Knowledge Graphs"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "_Authored by: [Diego Carpintero](https://github.com/dcarpintero)_"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Knowledge Graphs provide a method for modeling and storing interlinked information in a format that is both human- and machine-understandable. These graphs consist of *nodes* and *edges*, representing entities and their relationships. Unlike traditional databases, the inherent expressiveness of graphs allows for richer semantic understanding, while providing the flexibility to accommodate new entity types and relationships without being constrained by a fixed schema.\n",
    "\n",
    "By combining knowledge graphs with embeddings (vector search), we can leverage *multi-hop connectivity* and *contextual understanding of information* to enhance reasoning and explainability in LLMs.\n",
    "\n",
    "This notebook explores the practical implementation of this approach, demonstrating how to:\n",
    "- Build a knowledge graph in [Neo4j](https://neo4j.com/docs/) related to research publications using a synthetic dataset,\n",
    "- Project a subset of our data fields into a high-dimensional vector space using an [embedding model](https://python.langchain.com/v0.2/docs/integrations/text_embedding/),\n",
    "- Construct a vector index on those embeddings to enable similarity search, and\n",
    "- Extract insights from our graph using natural language by converting user queries into [Cypher](https://neo4j.com/docs/cypher-manual/current/introduction/) statements with [LangChain](https://python.langchain.com/v0.2/docs/introduction/):"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<p align=\"center\">\n",
    "  <img src=\"https://raw.githubusercontent.com/dcarpintero/generative-ai-101/main/static/knowledge-graphs.png\">\n",
    "</p>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Initialization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%pip install neo4j langchain langchain_openai langchain-community python-dotenv --quiet"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Set up a Neo4j instance\n",
    "\n",
    "We will create our Knowledge Graph using [Neo4j](https://neo4j.com/docs/), an open-source database management system that specializes in graph database technology.\n",
    "\n",
    "For a quick and easy setup, you can start a free instance on [Neo4j Aura](https://neo4j.com/product/auradb/).\n",
    "\n",
    "You might then set `NEO4J_URI`, `NEO4J_USERNAME`, and `NEO4J_PASSWORD` as environment variables using a `.env` file: "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import dotenv\n",
    "dotenv.load_dotenv('.env', override=True)"
   ]
  },
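  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The cells that follow read these variables through `os.environ[...]`, so a missing entry only surfaces later as a `KeyError`. As a minimal sketch, a check like the one below could report missing variables up front. The helper name `missing_env_vars` is hypothetical and is not part of `python-dotenv` or LangChain.\n",
    "\n",
    "```python\n",
    "import os\n",
    "\n",
    "REQUIRED_VARS = ['NEO4J_URI', 'NEO4J_USERNAME', 'NEO4J_PASSWORD']\n",
    "\n",
    "def missing_env_vars(names, env=None):\n",
    "    # Hypothetical helper: return the names that are unset or empty\n",
    "    # in the given mapping (defaults to the process environment).\n",
    "    env = os.environ if env is None else env\n",
    "    return [name for name in names if not env.get(name)]\n",
    "\n",
    "missing = missing_env_vars(REQUIRED_VARS)\n",
    "if missing:\n",
    "    print('Missing environment variables:', ', '.join(missing))\n",
    "```"
   ]
  },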
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "LangChain provides the `Neo4jGraph` class to interact with Neo4j:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from langchain_community.graphs import Neo4jGraph\n",
    "\n",
    "graph = Neo4jGraph(\n",
    "    url=os.environ['NEO4J_URI'], \n",
    "    username=os.environ['NEO4J_USERNAME'],\n",
    "    password=os.environ['NEO4J_PASSWORD'],\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Loading Dataset into a Graph\n",
    "\n",
    "The below example creates a connection with our `Neo4j` database and populates it with [synthetic data](https://github.com/dcarpintero/generative-ai-101/blob/main/dataset/synthetic_articles.csv) comprising research articles and their authors. \n",
    "\n",
    "The entities are: \n",
    "- *Researcher*\n",
    "- *Article*\n",
    "- *Topic*\n",
    "\n",
    "Whereas the relationships are:\n",
    "- *Researcher* --[PUBLISHED]--> *Article*\n",
    "- *Article* --[IN_TOPIC]--> *Topic*\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[]"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from langchain_community.graphs import Neo4jGraph\n",
    "\n",
    "graph = Neo4jGraph()\n",
    "\n",
    "q_load_articles = \"\"\"\n",
    "LOAD CSV WITH HEADERS\n",
    "FROM 'https://raw.githubusercontent.com/dcarpintero/generative-ai-101/main/dataset/synthetic_articles.csv' \n",
    "AS row \n",
    "FIELDTERMINATOR ';'\n",
    "MERGE (a:Article {title:row.Title})\n",
    "SET a.abstract = row.Abstract,\n",
    "    a.publication_date = date(row.Publication_Date)\n",
    "FOREACH (researcher in split(row.Authors, ',') | \n",
    "    MERGE (p:Researcher {name:trim(researcher)})\n",
    "    MERGE (p)-[:PUBLISHED]->(a))\n",
    "FOREACH (topic in [row.Topic] | \n",
    "    MERGE (t:Topic {name:trim(topic)})\n",
    "    MERGE (a)-[:IN_TOPIC]->(t))\n",
    "\"\"\"\n",
    "\n",
    "graph.query(q_load_articles)"
   ]
  },
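  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the two `FOREACH` clauses of the load query more concrete, the sketch below mirrors in plain Python what they do for a single row: split the author list on commas, trim each name, and derive one `PUBLISHED` edge per researcher plus one `IN_TOPIC` edge. The row is made up for illustration and is not taken from the dataset; nothing here talks to Neo4j.\n",
    "\n",
    "```python\n",
    "# A made-up row in the shape the query expects (not from the dataset).\n",
    "row = {\n",
    "    'Title': 'An Example Title',\n",
    "    'Authors': 'Ana Smith, Ben Lee',\n",
    "    'Topic': ' Example Topic ',\n",
    "}\n",
    "\n",
    "# Counterpart of split(row.Authors, ',') followed by trim(...) per element.\n",
    "researchers = [name.strip() for name in row['Authors'].split(',')]\n",
    "topic = row['Topic'].strip()\n",
    "\n",
    "# One PUBLISHED edge per researcher and one IN_TOPIC edge for the article.\n",
    "edges = [(name, 'PUBLISHED', row['Title']) for name in researchers]\n",
    "edges.append((row['Title'], 'IN_TOPIC', topic))\n",
    "print(edges)\n",
    "```\n",
    "\n",
    "In the actual query, `MERGE` additionally ensures that a researcher or topic appearing in several rows is created only once."
   ]
  },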
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's check that the nodes and relationships have been initialized correctly:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Node properties:\n",
      "Article {title: STRING, abstract: STRING, publication_date: DATE, embedding: LIST}\n",
      "Researcher {name: STRING}\n",
      "Topic {name: STRING}\n",
      "Relationship properties:\n",
      "\n",
      "The relationships:\n",
      "(:Article)-[:IN_TOPIC]->(:Topic)\n",
      "(:Researcher)-[:PUBLISHED]->(:Article)\n"
     ]
    }
   ],
   "source": [
    "graph.refresh_schema()\n",
    "print(graph.get_schema)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Our knowledge graph can be inspected in the Neo4j workspace:"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<p>\n",
    "  <img src=\"https://raw.githubusercontent.com/dcarpintero/generative-ai-101/main/static/kg_sample_00.png\">\n",
    "</p>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Building a Vector Index\n",
    "\n",
    "Now we construct a vector index to efficiently search for relevant *articles* based on their *topic, title, and abstract*. This involves computing an embedding for each article from these fields. At query time, the system embeds the user's input and retrieves the most similar articles using a similarity metric such as cosine similarity. Note that the graph built above models the topic as a separate `Topic` node rather than as an `Article` property, so the `topic` entry passed to the index is likely to contribute no text to the embedding.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain_community.vectorstores import Neo4jVector\n",
    "from langchain_openai import OpenAIEmbeddings\n",
    "\n",
    "vector_index = Neo4jVector.from_existing_graph(\n",
    "    OpenAIEmbeddings(),\n",
    "    url=os.environ['NEO4J_URI'],\n",
    "    username=os.environ['NEO4J_USERNAME'],\n",
    "    password=os.environ['NEO4J_PASSWORD'],\n",
    "    index_name='articles',\n",
    "    node_label=\"Article\",\n",
    "    text_node_properties=['topic', 'title', 'abstract'],\n",
    "    embedding_node_property='embedding',\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Note:** To access OpenAI embedding models you will need to create an OpenAI account, get an API key, and set `OPENAI_API_KEY` as an environment variable. You might also find it useful to experiment with another [embedding model](https://python.langchain.com/v0.2/docs/integrations/text_embedding/) integration."
   ]
  },
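  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough illustration of the similarity search performed at query time, the sketch below ranks a few items by cosine similarity in plain Python. The 3-dimensional vectors and article labels are made up; real embeddings have many more dimensions, and the actual search is carried out by the Neo4j vector index rather than by code like this.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "def cosine_similarity(a, b):\n",
    "    # Dot product divided by the product of the two vector norms.\n",
    "    dot = sum(x * y for x, y in zip(a, b))\n",
    "    norm_a = math.sqrt(sum(x * x for x in a))\n",
    "    norm_b = math.sqrt(sum(y * y for y in b))\n",
    "    return dot / (norm_a * norm_b)\n",
    "\n",
    "# Made-up vectors standing in for real article embeddings.\n",
    "toy_embeddings = {\n",
    "    'article A': [0.9, 0.1, 0.0],\n",
    "    'article B': [0.1, 0.8, 0.3],\n",
    "    'article C': [0.7, 0.3, 0.1],\n",
    "}\n",
    "query_vector = [1.0, 0.2, 0.0]\n",
    "\n",
    "# Rank articles from most to least similar to the query vector.\n",
    "ranked = sorted(\n",
    "    toy_embeddings,\n",
    "    key=lambda title: cosine_similarity(toy_embeddings[title], query_vector),\n",
    "    reverse=True,\n",
    ")\n",
    "print(ranked)\n",
    "```"
   ]
  },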
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Q&A on Similarity"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`Langchain RetrievalQA` creates a question-answering (QA) chain using the above vector index as a retriever."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain.chains import RetrievalQA\n",
    "from langchain_openai import ChatOpenAI\n",
    "\n",
    "vector_qa = RetrievalQA.from_chain_type(\n",
    "    llm=ChatOpenAI(),\n",
    "    chain_type=\"stuff\",\n",
    "    retriever=vector_index.as_retriever()\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's ask '*which articles discuss how AI might affect our daily life?*':"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The articles that discuss how AI might affect our daily life are:\n",
      "\n",
      "1. **The Impact of AI on Employment: A Comprehensive Study**\n",
      "   *Abstract:* This study analyzes the potential effects of AI on various job sectors and suggests policy recommendations to mitigate negative impacts.\n",
      "\n",
      "2. **The Societal Implications of Advanced AI: A Multidisciplinary Analysis**\n",
      "   *Abstract:* Our study brings together experts from various fields to analyze the potential long-term impacts of advanced AI on society, economy, and culture.\n",
      "\n",
      "These two articles would provide insights into how AI could potentially impact our daily lives from different perspectives.\n"
     ]
    }
   ],
   "source": [
    "r = vector_qa.invoke(\n",
    "    {\"query\": \"which articles discuss how AI might affect our daily life? include the article titles and abstracts.\"}\n",
    ")\n",
    "print(r['result'])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Traversing Knowledge Graphs for Inference"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Knowledge graphs are excellent for making connections between entities, enabling the extraction of patterns and the discovery of new insights.\n",
    "\n",
    "This section demonstrates how to implement this process and integrate the results into an LLM pipeline using natural language queries."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Graph-Cypher-Chain w/ LangChain\n",
    "\n",
    "To construct expressive and efficient queries `Neo4j` users `Cypher`, a declarative query language inspired by SQL. `LangChain` provides the wrapper `GraphCypherQAChain`, an abstraction layer that allows querying graph databases using natural language, making it easier to integrate graph-based data retrieval into LLM pipelines.\n",
    "\n",
    "In practice, `GraphCypherQAChain`:\n",
    "- generates Cypher statements (queries for graph databases like Neo4j) from user input (natural language) applying in-context learning (prompt engineering),\n",
    "- executes said statements against a graph database, and \n",
    "- provides the results as context to ground the LLM responses on accurate, up-to-date information:\n",
    "\n",
    "**Note:** This implementation involves executing model-generated graph queries, which carries inherent risks such as unintended access or modification of sensitive data in the database. To mitigate these risks, ensure that your database connection permissions are as restricted as possible to meet the specific needs of your chain/agent. While this approach reduces risk, it does not eliminate it entirely."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain.chains import GraphCypherQAChain\n",
    "from langchain_openai import ChatOpenAI\n",
    "\n",
    "graph.refresh_schema()\n",
    "\n",
    "cypher_chain = GraphCypherQAChain.from_llm(\n",
    "    cypher_llm = ChatOpenAI(temperature=0, model_name='gpt-4o'),\n",
    "    qa_llm = ChatOpenAI(temperature=0, model_name='gpt-4o'), \n",
    "    graph=graph,\n",
    "    verbose=True,\n",
    ")"
   ]
  },
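  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Conceptually, the three steps performed by the chain (generate Cypher, execute it, answer from the results) can be sketched as below. This is only an illustration of the control flow, not LangChain's implementation: `answer_with_graph` and the three `fake_*` callables are hypothetical stand-ins so that the sketch runs without a database or an API key.\n",
    "\n",
    "```python\n",
    "def answer_with_graph(question, schema, generate_cypher, run_query, generate_answer):\n",
    "    # 1. Translate the question into Cypher, given the graph schema.\n",
    "    cypher = generate_cypher(question, schema)\n",
    "    # 2. Execute the generated statement against the graph database.\n",
    "    rows = run_query(cypher)\n",
    "    # 3. Phrase an answer that is grounded in the returned rows.\n",
    "    return generate_answer(question, rows)\n",
    "\n",
    "# Hypothetical stand-ins for the two LLM calls and the database call.\n",
    "def fake_generate_cypher(question, schema):\n",
    "    return 'MATCH (a:Article) RETURN count(a) AS n'\n",
    "\n",
    "def fake_run_query(cypher):\n",
    "    return [{'n': 3}]\n",
    "\n",
    "def fake_generate_answer(question, rows):\n",
    "    return 'There are {} articles.'.format(rows[0]['n'])\n",
    "\n",
    "print(answer_with_graph(\n",
    "    'How many articles are there?',\n",
    "    '(:Researcher)-[:PUBLISHED]->(:Article)',\n",
    "    fake_generate_cypher,\n",
    "    fake_run_query,\n",
    "    fake_generate_answer,\n",
    "))\n",
    "```\n",
    "\n",
    "In the real chain, the first and third steps are LLM calls and the second step runs against `graph`; with `verbose=True` the generated Cypher and the returned rows are printed, as the examples that follow show."
   ]
  },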
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Query Samples using Natural Language"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Note in the following examples how the results from the cypher query execution are provided as context to the LLM:"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### **\"*How many articles has published Emily Chen?*\"**"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In this example, our question '*How many articles has published Emily Chen?*' will be translated into the Cyper query:\n",
    "\n",
    "```\n",
    "MATCH (r:Researcher {name: \"Emily Chen\"})-[:PUBLISHED]->(a:Article)\n",
    "RETURN COUNT(a) AS numberOfArticles\n",
    "```\n",
    "\n",
    "which matches nodes labeled `Author` with the name 'Emily Chen' and traverses the `PUBLISHED` relationships to `Article` nodes. \n",
    "It then counts the number of `Article` nodes connected to 'Emily Chen':"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<p>\n",
    "  <img src=\"https://raw.githubusercontent.com/dcarpintero/generative-ai-101/main/static/kg_sample_01.png\" width=\"40%\">\n",
    "</p>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\n",
      "\u001b[1m> Entering new GraphCypherQAChain chain...\u001b[0m\n",
      "Generated Cypher:\n",
      "\u001b[32;1m\u001b[1;3mcypher\n",
      "MATCH (r:Researcher {name: \"Emily Chen\"})-[:PUBLISHED]->(a:Article)\n",
      "RETURN COUNT(a) AS numberOfArticles\n",
      "\u001b[0m\n",
      "Full Context:\n",
      "\u001b[32;1m\u001b[1;3m[{'numberOfArticles': 7}]\u001b[0m\n",
      "\n",
      "\u001b[1m> Finished chain.\u001b[0m\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "{'query': 'How many articles has published Emily Chen?',\n",
       " 'result': 'Emily Chen has published 7 articles.'}"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# the answer should be '7'\n",
    "cypher_chain.invoke(\n",
    "    {\"query\": \"How many articles has published Emily Chen?\"}\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### **\"*Are there any pair of researchers who have published more than three articles together?*\"**"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In this example, the query '*are there any pair of researchers who have published more than three articles together?*' results in the Cypher query:\n",
    "\n",
    "```\n",
    "MATCH (r1:Researcher)-[:PUBLISHED]->(a:Article)<-[:PUBLISHED]-(r2:Researcher)\n",
    "WHERE r1 <> r2\n",
    "WITH r1, r2, COUNT(a) AS sharedArticles\n",
    "WHERE sharedArticles > 3\n",
    "RETURN r1.name, r2.name, sharedArticles\n",
    "```\n",
    "\n",
    "which results in traversing from the `Researcher` nodes to the `PUBLISHED` relationship to find connected `Article` nodes, and then traversing back to find `Researchers` pairs."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<p>\n",
    "  <img src=\"https://raw.githubusercontent.com/dcarpintero/generative-ai-101/main/static/kg_sample_02.png\">\n",
    "</p>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\n",
      "\u001b[1m> Entering new GraphCypherQAChain chain...\u001b[0m\n",
      "Generated Cypher:\n",
      "\u001b[32;1m\u001b[1;3mcypher\n",
      "MATCH (r1:Researcher)-[:PUBLISHED]->(a:Article)<-[:PUBLISHED]-(r2:Researcher)\n",
      "WHERE r1 <> r2\n",
      "WITH r1, r2, COUNT(a) AS sharedArticles\n",
      "WHERE sharedArticles > 3\n",
      "RETURN r1.name, r2.name, sharedArticles\n",
      "\u001b[0m\n",
      "Full Context:\n",
      "\u001b[32;1m\u001b[1;3m[{'r1.name': 'David Johnson', 'r2.name': 'Emily Chen', 'sharedArticles': 4}, {'r1.name': 'Robert Taylor', 'r2.name': 'Emily Chen', 'sharedArticles': 4}, {'r1.name': 'Emily Chen', 'r2.name': 'David Johnson', 'sharedArticles': 4}, {'r1.name': 'Emily Chen', 'r2.name': 'Robert Taylor', 'sharedArticles': 4}]\u001b[0m\n",
      "\n",
      "\u001b[1m> Finished chain.\u001b[0m\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "{'query': 'are there any pair of researchers who have published more than three articles together?',\n",
       " 'result': 'Yes, David Johnson and Emily Chen, as well as Robert Taylor and Emily Chen, have published more than three articles together.'}"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# the answer should be David Johnson & Emily Chen, Robert Taylor & Emily Chen\n",
    "cypher_chain.invoke(\n",
    "    {\"query\": \"are there any pair of researchers who have published more than three articles together?\"}\n",
    ")"
   ]
  },
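  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Because the pattern matches every pair in both orders, the context above contains four rows for two pairs of researchers. If the raw rows were consumed directly instead of being summarized by the LLM, they could be collapsed with a small helper such as the one sketched below. The function name `unique_pairs` is hypothetical; the rows are copied from the output above.\n",
    "\n",
    "```python\n",
    "# Rows copied from the 'Full Context' of the run above.\n",
    "rows = [\n",
    "    {'r1.name': 'David Johnson', 'r2.name': 'Emily Chen', 'sharedArticles': 4},\n",
    "    {'r1.name': 'Robert Taylor', 'r2.name': 'Emily Chen', 'sharedArticles': 4},\n",
    "    {'r1.name': 'Emily Chen', 'r2.name': 'David Johnson', 'sharedArticles': 4},\n",
    "    {'r1.name': 'Emily Chen', 'r2.name': 'Robert Taylor', 'sharedArticles': 4},\n",
    "]\n",
    "\n",
    "def unique_pairs(rows):\n",
    "    # Treat (A, B) and (B, A) as the same pair by sorting the two names.\n",
    "    seen = {}\n",
    "    for row in rows:\n",
    "        pair = tuple(sorted((row['r1.name'], row['r2.name'])))\n",
    "        seen[pair] = row['sharedArticles']\n",
    "    return seen\n",
    "\n",
    "for (first, second), shared in sorted(unique_pairs(rows).items()):\n",
    "    print(first, '&', second, '->', shared)\n",
    "```"
   ]
  },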
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### **\"*Which researcher has collaborated with the most peers?*\"**\n",
    "\n",
    "Let's find out which researcher has collaborated with the most peers. \n",
    "Our query '*Which researcher has collaborated with the most peers?*' is now translated into a Cypher query along these lines (the exact statement generated by the LLM may vary between runs):\n",
    "\n",
    "```\n",
    "MATCH (r:Researcher)-[:PUBLISHED]->(:Article)<-[:PUBLISHED]-(peer:Researcher)\n",
    "WITH r, COUNT(DISTINCT peer) AS peerCount\n",
    "RETURN r.name AS researcher, peerCount\n",
    "ORDER BY peerCount DESC\n",
    "LIMIT 1\n",
    "```\n",
    "\n",
    "Here, the query starts from all `Researcher` nodes and traverses their `PUBLISHED` relationships to find connected `Article` nodes. For each `Article` node, Neo4j then traverses back to find the other `Researcher` nodes (peers) who have also published the same article."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<p>\n",
    "  <img src=\"https://raw.githubusercontent.com/dcarpintero/generative-ai-101/main/static/kg_sample_03.png\">\n",
    "</p>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\n",
      "\u001b[1m> Entering new GraphCypherQAChain chain...\u001b[0m\n",
      "Generated Cypher:\n",
      "\u001b[32;1m\u001b[1;3mcypher\n",
      "MATCH (r1:Researcher)-[:PUBLISHED]->(:Article)<-[:PUBLISHED]-(r2:Researcher)\n",
      "WHERE r1 <> r2\n",
      "WITH r1, COUNT(DISTINCT r2) AS collaborators\n",
      "RETURN r1.name AS researcher, collaborators\n",
      "ORDER BY collaborators DESC\n",
      "LIMIT 1\n",
      "\u001b[0m\n",
      "Full Context:\n",
      "\u001b[32;1m\u001b[1;3m[{'researcher': 'David Johnson', 'collaborators': 6}]\u001b[0m\n",
      "\n",
      "\u001b[1m> Finished chain.\u001b[0m\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "{'query': 'Which researcher has collaborated with the most peers?',\n",
       " 'result': 'David Johnson has collaborated with 6 peers.'}"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# the answer should be 'David Johnson'\n",
    "cypher_chain.invoke(\n",
    "    {\"query\": \"Which researcher has collaborated with the most peers?\"}\n",
    ")"
   ]
  },
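  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For intuition, the aggregation behind this query (count the distinct co-authors of each researcher, then take the maximum) can be sketched in plain Python. The authorship data and names below are made up and unrelated to the dataset; in the notebook the traversal and the counting are done by Neo4j.\n",
    "\n",
    "```python\n",
    "from collections import defaultdict\n",
    "\n",
    "# Made-up authorship data: article title -> researchers who published it.\n",
    "articles = {\n",
    "    'paper 1': ['Ana', 'Ben', 'Cleo'],\n",
    "    'paper 2': ['Ana', 'Ben'],\n",
    "    'paper 3': ['Cleo', 'Dan'],\n",
    "}\n",
    "\n",
    "def count_distinct_peers(articles):\n",
    "    # For every researcher, collect the set of co-authors across all articles.\n",
    "    peers = defaultdict(set)\n",
    "    for authors in articles.values():\n",
    "        for researcher in authors:\n",
    "            peers[researcher].update(a for a in authors if a != researcher)\n",
    "    return {name: len(others) for name, others in peers.items()}\n",
    "\n",
    "counts = count_distinct_peers(articles)\n",
    "top = max(counts, key=counts.get)\n",
    "print(top, counts[top])\n",
    "```\n",
    "\n",
    "Using a set per researcher plays the role of `COUNT(DISTINCT ...)`: a peer who co-authored several articles is counted once."
   ]
  },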
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "----"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
